package main

import (
	"bytes"
	"embed"
	"fmt"
	"io"
	"io/fs"
	"log"
	"net/http"
	"os"
	"path"
	"path/filepath"
	"strings"
	"time"

	"github.com/hillu/go-yara/v4"
	"github.com/jessevdk/go-flags"
)

const (
	// Rules location. --update downloads each rules file from RulesURI followed
	// by the file's relative path, so RulesURI must end with a slash.
	RulesURI  = "https://raw.githubusercontent.com/jvoisin/php-malware-finder/master/php-malware-finder/data/"
	RulesFile = "php.yar"

	// ScanMaxDuration is the per-file YARA scan timeout. It is a time.Duration,
	// so the unit must be explicit: time.Duration(60) would be 60 nanoseconds.
	ScanMaxDuration = 60 * time.Second

	// --long-lines heuristic: a file with at most TooShortMaxLines newlines but
	// at least TooShortMinChars bytes gets the synthetic TooShort match.
	TooShort         = "TooShort"
	TooShortMaxLines = 2
	TooShortMinChars = 300

	// Folder-scan scoring: each matched rule adds DangerousMatchWeight, a
	// dangerous-rule bonus may be added, and files scoring at least
	// DangerousMinScore are reported as dangerous.
	DangerousMatchWeight = 2
	DangerousMinScore    = 3

	FileBufferSize = 32 * 1024 // read buffer size used by fileStats (32KB)
	YaraMaxThreads = 32        // YARA's MAX_THREADS; --workers is capped to it
	TempDirPrefix  = "pmf-"    // name prefix of the temporary rules folder

	// Process exit codes.
	ExitCodeOk          = 0
	ExitCodeWithMatches = 255
	ExitCodeWithError   = 1
)

var (
	args struct { // command-line arguments specs using github.com/jessevdk/go-flags
		RulesDir      string   `short:"r" long:"rules-dir" description:"Alternative rules location (default: embedded rules)"`
		ShowAll       bool     `short:"a" long:"show-all" description:"Display all matched rules"`
		Fast          bool     `short:"f" long:"fast" description:"Enable YARA's fast mode"`
		RateLimit     int      `short:"R" long:"rate-limit" description:"Max. filesystem ops per second, 0 for no limit" default:"0"`
		Verbose       bool     `short:"v" long:"verbose" description:"Verbose mode"`
		Workers       int      `short:"w" long:"workers" description:"Number of workers to spawn for scanning" default:"32"`
		LongLines     bool     `short:"L" long:"long-lines" description:"Check long lines"`
		ExcludeCommon bool     `short:"c" long:"exclude-common" description:"Do not scan files with common extensions"`
		ExcludeImgs   bool     `short:"i" long:"exclude-imgs" description:"Do not scan image files"`
		ExcludedExts  []string `short:"x" long:"exclude-ext" description:"Additional file extensions to exclude"`
		Update        bool     `short:"u" long:"update" description:"Update rules"`
		ShowVersion   bool     `short:"V" long:"version" description:"Show version number and exit"`
		Positional    struct {
			Target string
		} `positional-args:"yes"`
	}
	scanFlags        yara.ScanFlags
	stoppedWorkers   int
	lineFeed         = []byte{'\n'}
	dangerousMatches = map[string]struct{}{
		"PasswordProtection": {},
		"Websites":           {},
		"TooShort":           {},
		"NonPrintableChars":  {},
	}
	excludedDirs = [...]string{
		"/.git/", "/.hg/", "/.svn/", "/.CVS/",
	}
	excludedExts = map[string]struct{}{}
	commonExts   = [...]string{
		".js", ".coffee", ".map", ".min", ".css", ".less", // static files
		".zip", ".rar", ".7z", ".gz", ".bz2", ".xz", ".tar", ".tgz", // archives
		".txt", ".csv", ".json", ".rst", ".md", ".yaml", ".yml", // plain text
		".so", ".dll", ".bin", ".exe", ".bundle", // binaries
	}
	imageExts = [...]string{
		".png", ".jpg", ".jpeg", ".gif", ".svg", ".bmp", ".ico",
	}
	scannedFilesCount int
	rulesFiles        = [...]string{
		RulesFile, "whitelist.yar",
		"whitelists/custom.yar", "whitelists/drupal.yar", "whitelists/magento1ce.yar", "whitelists/magento2.yar",
		"whitelists/phpmyadmin.yar", "whitelists/prestashop.yar", "whitelists/symfony.yar", "whitelists/wordpress.yar",
	}
	tempDirPathPrefix = path.Join(os.TempDir(), TempDirPrefix)
	version           = "dev"

	//go:embed data/php.yar data/whitelist.yar data/whitelists
	data embed.FS
)

// handleError is a no-op for a nil err. Otherwise it logs the error as
// "[ERROR] <desc>: <err>" (or "[ERROR] <err>" when desc is empty) and, when
// isFatal is set, terminates the process with ExitCodeWithError. Because
// os.Exit is used, deferred calls in the caller do not run on the fatal path.
func handleError(err error, desc string, isFatal bool) {
	if err == nil {
		return
	}

	prefix := "[ERROR]"
	if desc != "" {
		prefix += " " + desc + ":"
	}
	log.Println(prefix, err)

	if !isFatal {
		return
	}
	os.Exit(ExitCodeWithError)
}

// writeRulesFiles copies every file listed in rulesFiles from the "data"
// directory of `content` into a freshly created temporary folder (named with
// TempDirPrefix, containing a "whitelists" subfolder) and returns that
// folder's path. Any failure is fatal.
func writeRulesFiles(content fs.FS) string {
	// create temporary folder structure
	rulesPath, err := os.MkdirTemp(os.TempDir(), TempDirPrefix)
	handleError(err, "unable to create temporary folder", true)
	err = os.Mkdir(filepath.Join(rulesPath, "whitelists"), 0755)
	handleError(err, "unable to create temporary subfolder", true)

	// write each YARA file to the disk
	for _, rulesFile := range rulesFiles {
		// fs.FS paths are always slash-separated, hence path (not filepath) here.
		ruleData, err := fs.ReadFile(content, path.Join("data", rulesFile))
		handleError(err, "unable to read embedded rule", true)

		// On-disk paths must use the OS separator.
		dst := filepath.Join(rulesPath, filepath.FromSlash(rulesFile))
		err = os.WriteFile(dst, ruleData, 0640)
		handleError(err, "unable to write rule to disk", true)
	}
	return rulesPath
}

// updateRules downloads the latest YARA rules from the php-malware-finder
// GitHub repository into args.RulesDir, creating subfolders as needed.
// It refuses to run when args.RulesDir is the temporary folder holding the
// embedded rules (that folder is deleted at exit), so --rules-dir must be
// given explicitly. Any download or write failure is fatal; files written
// before the failure stay updated.
func updateRules() {
	if strings.HasPrefix(args.RulesDir, tempDirPathPrefix) {
		handleError(fmt.Errorf("rules folder must be specified to update"), "", true)
	}
	if args.Verbose {
		log.Println("[DEBUG] updating ruleset")
	}

	// A bounded timeout keeps --update from hanging on a stalled connection.
	client := &http.Client{Timeout: 60 * time.Second}
	// Exactly one slash between base URI and rule path, whatever RulesURI ends with.
	baseURI := strings.TrimSuffix(RulesURI, "/") + "/"

	downloadFile := func(uri string) []byte {
		resp, err := client.Get(uri)
		handleError(err, "unable to download rule", true)
		defer func() {
			err := resp.Body.Close()
			handleError(err, "unable to close response body", false)
		}()
		// Without this check an HTTP error page (e.g. "404: Not Found") would be
		// written to disk as if it were a rules file.
		if resp.StatusCode != http.StatusOK {
			handleError(fmt.Errorf("unexpected HTTP status %q for %s", resp.Status, uri), "unable to download rule", true)
		}
		body, err := io.ReadAll(resp.Body)
		// A truncated body would produce a corrupt rules file: treat it as fatal.
		handleError(err, "unable to read response body", true)
		return body
	}
	writeFile := func(dst string, content []byte) {
		// whitelists/ may not exist yet in a user-supplied rules folder.
		err := os.MkdirAll(filepath.Dir(dst), 0755)
		handleError(err, "unable to create rules subfolder", true)
		err = os.WriteFile(dst, content, 0640)
		handleError(err, "unable to write downloaded file", true)
	}

	// Files fetched from the repository. Unlike the embedded rulesFiles set,
	// whitelists/custom.yar is not downloaded. NOTE(review): presumably because
	// it holds user-maintained entries; confirm.
	remoteRules := [...]string{
		RulesFile,
		"whitelist.yar", "whitelists/drupal.yar", "whitelists/magento1ce.yar",
		"whitelists/magento2.yar", "whitelists/phpmyadmin.yar", "whitelists/prestashop.yar",
		"whitelists/symfony.yar", "whitelists/wordpress.yar"}

	// download rules
	for _, rule := range remoteRules {
		content := downloadFile(baseURI + rule)
		outPath := filepath.Join(args.RulesDir, filepath.FromSlash(rule))
		writeFile(outPath, content)
		log.Println("[INFO] updated rule:", rule)
	}
}

// fileStats returns, in this order, the number of bytes ("characters") and
// the number of '\n' bytes ("lines") in the file at filepath. A final line
// without a trailing newline is not counted. The file is read in
// FileBufferSize chunks, so memory use does not grow with file size.
// If the file cannot be opened, (0, 0, err) is returned. If a read fails,
// the counts gathered so far are returned with the error.
func fileStats(filepath string) (int, int, error) {
	file, openErr := os.Open(filepath)
	if openErr != nil {
		return 0, 0, openErr
	}
	defer func() {
		handleError(file.Close(), "unable to close file", false)
	}()

	var chars, lines int
	chunk := make([]byte, FileBufferSize)
	for {
		n, readErr := file.Read(chunk)
		// Count whatever was read before looking at the error: Read may return
		// data together with io.EOF.
		chars += n
		lines += bytes.Count(chunk[:n], lineFeed)
		if readErr == io.EOF {
			return chars, lines, nil
		}
		if readErr != nil {
			return chars, lines, readErr
		}
	}
}

// makeScanner builds a YARA scanner for `rules`, configured with the
// package-wide scanFlags and the ScanMaxDuration timeout. Failing to create
// the scanner is fatal.
func makeScanner(rules *yara.Rules) *yara.Scanner {
	s, err := yara.NewScanner(rules)
	if err != nil {
		handleError(err, "unable to create YARA scanner", true)
	}
	s.SetTimeout(ScanMaxDuration)
	s.SetFlags(scanFlags)
	return s
}

// workersLock serializes updates to the counters shared by every processFiles
// goroutine (scannedFilesCount and stoppedWorkers). It is a one-slot buffered
// channel used as a mutex: sending acquires it, receiving releases it.
var workersLock = make(chan struct{}, 1)

// processFiles is a scan worker. It reads file paths from `targets` until that
// channel is closed. Before each file it waits on `ticker` (e.g. from
// time.Tick) to throttle I/O. It scans the file with a scanner built from
// `rules` and sends one {path: matched rules} map to `results` per
// successfully scanned file; the slice is empty when nothing matched.
// When args.LongLines is set, fileStats also runs and a synthetic TooShort
// match is added for files with few newlines but many bytes. Files YARA fails
// to scan are logged and skipped. The worker that brings stoppedWorkers to
// args.Workers closes `results`, so the consumer's range loop ends.
func processFiles(rules *yara.Rules, targets <-chan string, results chan<- map[string][]yara.MatchRule, ticker <-chan time.Time) {
	scanner := makeScanner(rules)
	for target := range targets {
		<-ticker
		workersLock <- struct{}{}
		scannedFilesCount++
		<-workersLock

		result := map[string][]yara.MatchRule{target: {}}

		if args.LongLines {
			charCount, lineCount, err := fileStats(target)
			handleError(err, "unable to get file stats", false)
			if lineCount <= TooShortMaxLines && charCount >= TooShortMinChars {
				result[target] = append(result[target], yara.MatchRule{Rule: TooShort})
			}
		}

		var matches yara.MatchRules
		if err := scanner.SetCallback(&matches).ScanFile(target); err != nil {
			log.Println("[ERROR]", err)
			continue
		}
		result[target] = append(result[target], matches...)
		results <- result
	}

	// Without the lock, concurrent increments can be lost (results never
	// closed, consumer deadlocks) or two workers can both see the final count
	// (double close panics).
	workersLock <- struct{}{}
	stoppedWorkers++
	isLast := stoppedWorkers == args.Workers
	<-workersLock
	if isLast {
		close(results)
	}
}

// scanDir recursively walks dirName and sends to `targets` the path of every
// non-directory entry, except:
//   - entries whose slash-separated path contains one of excludedDirs, and
//   - files whose extension is in excludedExts.
//
// Directories matching excludedDirs are skipped as a whole. Every visited
// entry first waits on `ticker` to throttle filesystem operations. Walk
// errors are logged and the crawl continues. `targets` is closed once the
// walk is over.
func scanDir(dirName string, targets chan<- string, ticker <-chan time.Time) {
	isExcluded := func(slashPath string) bool {
		for _, dir := range excludedDirs {
			if strings.Contains(slashPath, dir) {
				return true
			}
		}
		return false
	}

	visit := func(pathName string, fileInfo os.FileInfo, err error) error {
		<-ticker
		if err != nil {
			// fileInfo can be nil here (e.g. an entry that vanished mid-walk),
			// so report the problem and carry on with the rest of the tree.
			handleError(err, "unable to access path", false)
			return nil
		}
		// excludedDirs use '/' separators; normalize so they match on Windows too.
		slashPath := filepath.ToSlash(pathName)
		if fileInfo.IsDir() {
			// With a trailing slash, "/.git/" also matches the directory itself,
			// which lets Walk skip the whole subtree.
			if isExcluded(slashPath + "/") {
				return filepath.SkipDir
			}
			return nil
		}
		if isExcluded(slashPath) {
			return nil
		}
		if _, excluded := excludedExts[filepath.Ext(fileInfo.Name())]; !excluded {
			targets <- pathName
		}
		return nil
	}
	err := filepath.Walk(dirName, visit)
	handleError(err, "unable to complete target crawling", false)
	close(targets)
}

// loadRulesFile reads the YARA rules in fileName and returns them compiled.
// While reading and compiling, the process working directory is switched to
// the folder containing fileName. NOTE(review): presumably this is so that
// relative include directives inside the rules resolve against it; confirm.
// The working directory is restored before returning, on success and on
// failure alike. It is process-wide state, so this function is not safe to
// call while other goroutines depend on relative paths.
func loadRulesFile(fileName string) (rules *yara.Rules, err error) {
	// record working directory and move to rules location
	curDir, err := os.Getwd()
	if err != nil {
		return nil, fmt.Errorf("unable to determine working directory: %w", err)
	}

	ruleDir, ruleName := filepath.Split(fileName)
	if ruleDir == "" {
		// A bare file name (e.g. from `--rules-dir .`) lives in the current
		// directory; os.Chdir("") would fail instead.
		ruleDir = "."
	}
	if cdErr := os.Chdir(ruleDir); cdErr != nil {
		return nil, fmt.Errorf("unable to move to rules directory: %w", cdErr)
	}
	defer func() {
		// Move back on every return path. A failure here is only reported when
		// nothing else went wrong, so the first error wins.
		if cdErr := os.Chdir(curDir); cdErr != nil && err == nil {
			rules, err = nil, fmt.Errorf("unable to move back to working directory: %w", cdErr)
		}
	}()

	// read file content
	source, err := os.ReadFile(ruleName)
	if err != nil {
		return nil, fmt.Errorf("unable to read rules file: %w", err)
	}

	// compile rules
	rules, err = yara.Compile(string(source), nil)
	if err != nil {
		return nil, fmt.Errorf("unable to load rules: %w", err)
	}
	return rules, nil
}

// main parses the command line and prepares the YARA rules. Unless
// --rules-dir is given, the embedded rules are written to a temporary folder.
// It then either updates the rules or scans the target:
//   - a folder is crawled by scanDir and scanned by args.Workers processFiles
//     goroutines;
//   - a single file is scanned directly.
//
// The process exits with ExitCodeWithMatches when something suspicious was
// reported, ExitCodeWithError on fatal errors, and ExitCodeOk otherwise.
func main() {
	startTime := time.Now()
	matchesFound := false
	if _, err := flags.Parse(&args); err != nil {
		// With its default options go-flags reports the problem (or prints
		// the help) itself.
		os.Exit(ExitCodeWithError)
	}
	if args.ShowVersion {
		fmt.Println(version)
		os.Exit(ExitCodeOk)
	}

	// check rules path
	if args.RulesDir == "" {
		args.RulesDir = writeRulesFiles(data)
	}
	if args.Verbose {
		log.Println("[DEBUG] rules directory:", args.RulesDir)
	}

	// update rules if required
	if args.Update {
		updateRules()
		os.Exit(ExitCodeOk)
	}

	// build the set of excluded file extensions (keys carry a leading dot)
	if args.ExcludeCommon {
		for _, commonExt := range commonExts {
			excludedExts[commonExt] = struct{}{}
		}
	}
	if args.ExcludeImgs || args.ExcludeCommon {
		for _, imgExt := range imageExts {
			excludedExts[imgExt] = struct{}{}
		}
	}
	for _, ext := range args.ExcludedExts {
		if ext == "" {
			continue // `-x ""` names no extension
		}
		if !strings.HasPrefix(ext, ".") {
			ext = "." + ext
		}
		excludedExts[ext] = struct{}{}
	}
	if args.Verbose {
		extList := make([]string, 0, len(excludedExts))
		for ext := range excludedExts {
			extList = append(extList, strings.TrimPrefix(ext, "."))
		}
		log.Println("[DEBUG] excluded file extensions:", strings.Join(extList, ","))
	}

	// load YARA rules
	rulePath := filepath.Join(args.RulesDir, RulesFile)
	rules, err := loadRulesFile(rulePath)
	handleError(err, "", true)
	if args.Verbose {
		log.Println("[DEBUG] ruleset loaded:", rulePath)
	}

	// set YARA scan flags
	if args.Fast {
		scanFlags = yara.ScanFlags(yara.ScanFlagsFastMode)
	} else {
		scanFlags = yara.ScanFlags(0)
	}

	// Keep the worker count within [1, YARA's MAX_THREADS]. With no worker,
	// nothing drains the targets channel and the folder scan would deadlock.
	if args.Workers > YaraMaxThreads {
		log.Printf("[WARNING] workers count too high, using %d instead of %d\n", YaraMaxThreads, args.Workers)
		args.Workers = YaraMaxThreads
	}
	if args.Workers < 1 {
		log.Printf("[WARNING] workers count too low, using 1 instead of %d\n", args.Workers)
		args.Workers = 1
	}

	// scan target
	target := args.Positional.Target
	if target == "" {
		handleError(fmt.Errorf("no target specified"), "", true)
	}
	targetInfo, err := os.Stat(target)
	// Any Stat failure is fatal, not only "does not exist": targetInfo is nil then.
	handleError(err, "", true)
	if args.Verbose {
		log.Println("[DEBUG] scan workers:", args.Workers)
		log.Println("[DEBUG] target:", target)
	}

	if targetInfo.IsDir() { // parallelized folder scan
		// create communication channels
		targets := make(chan string)
		results := make(chan map[string][]yara.MatchRule)

		// Rate limit. time.Tick returns nil (a channel that never delivers) for
		// a non-positive period, so both a non-positive --rate-limit and one
		// above 1e9 ops/s fall back to the 1ns minimum.
		tickerRate := time.Nanosecond
		if args.RateLimit > 0 {
			if rate := time.Second / time.Duration(args.RateLimit); rate > tickerRate {
				tickerRate = rate
			}
		}
		ticker := time.Tick(tickerRate)
		if args.Verbose {
			log.Println("[DEBUG] delay between fs ops:", tickerRate.String())
		}

		// start consumers and producer workers
		for w := 0; w < args.Workers; w++ {
			go processFiles(rules, targets, results, ticker)
		}
		go scanDir(target, targets, ticker)

		// Read results and score each file that matched something. Every match
		// weighs DangerousMatchWeight. A single bonus point is added when any of
		// the file's matches is in dangerousMatches.
		matchCount := make(map[string]int)
		for result := range results {
			for file, matchedSigs := range result {
				if len(matchedSigs) == 0 {
					continue
				}
				score, dangerous := 0, false
				for i, sig := range matchedSigs {
					score += DangerousMatchWeight
					if _, ok := dangerousMatches[sig.Rule]; ok {
						dangerous = true
					}
					// only the first match is listed unless --show-all
					if i == 0 || args.ShowAll {
						log.Printf("[WARNING] match found: %s (%s)\n", file, sig.Rule)
					}
				}
				if dangerous {
					score++
				}
				matchCount[file] = score
			}
		}
		for file, count := range matchCount {
			if count >= DangerousMinScore {
				log.Println("[WARNING] dangerous file found:", file)
				matchesFound = true
			}
		}
	} else { // single file mode
		scannedFilesCount++
		var matches yara.MatchRules
		scanner := makeScanner(rules)
		err = scanner.SetCallback(&matches).ScanFile(target)
		handleError(err, "unable to scan target", true)
		for _, match := range matches {
			matchesFound = true
			log.Println("[WARNING] match found:", target, match.Rule)
			if args.Verbose {
				for _, matchString := range match.Strings {
					log.Printf("[DEBUG] match string for %s: 0x%x:%s: %s\n", target, matchString.Offset, matchString.Name, matchString.Data)
				}
			}
			if !args.ShowAll {
				break
			}
		}
	}
	if args.Verbose {
		log.Printf("[DEBUG] scanned %d files in %s\n", scannedFilesCount, time.Since(startTime).String())
	}

	// delete temporary files
	if strings.HasPrefix(args.RulesDir, tempDirPathPrefix) {
		if args.Verbose {
			log.Println("[DEBUG] deleting temporary folder:", args.RulesDir)
		}
		err = os.RemoveAll(args.RulesDir)
		handleError(err, fmt.Sprintf("unable to delete temporary folder '%s'", args.RulesDir), false)
	}

	if matchesFound {
		os.Exit(ExitCodeWithMatches)
	}
	os.Exit(ExitCodeOk)
}
